import matplotlib.pyplot as plt

# Preview the first few training images side by side in a single row, axes hidden.
# NOTE(review): X_train is defined outside this snippet; it is assumed to support
# len() and integer indexing, with each item accepted by `Axes.imshow`. Confirm
# against the data loader.
n_preview = min(8, len(X_train))  # avoid indexing past the end when fewer than 8 samples exist

# squeeze=False keeps `axs` a 2-D array even when n_preview == 1, so axs[0]
# is always a row of Axes and indexing stays uniform.
fig, axs = plt.subplots(1, n_preview, figsize=(17, 6), squeeze=False)
for i, ax in enumerate(axs[0]):
    ax.imshow(X_train[i])
    ax.axis('off')
plt.show()
